{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c",
   "metadata": {
    "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c"
   },
   "source": [
    "<table style=\"width:100%\">\n",
    "<tr>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<font size=\"2\">\n",
    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
    "</font>\n",
    "</td>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
    "</td>\n",
    "</tr>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efde77f2-6af3-4781-8597-89ecd3f41a52",
   "metadata": {
    "id": "efde77f2-6af3-4781-8597-89ecd3f41a52"
   },
   "source": [
    "# Olmo 3 From Scratch (A Standalone Notebook)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d",
   "metadata": {
    "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d"
   },
   "source": [
    "- This notebook is purposefully minimal and focuses on the code to re-implement the Olmo 3 7B and 32B models from Allen AI in pure PyTorch, without relying on external LLM libraries; Olmo 3 is interesting because, at the time of writing, it is the leading fully open-source model\n",
    "- For more information, see the official [Olmo 3 announcement](https://allenai.org/blog/olmo3) and model cards:\n",
    "  - [Olmo-3-1025-7B](https://huggingface.co/allenai/Olmo-3-1025-7B) (base model)\n",
    "  - [Olmo-3-7B-Instruct](https://huggingface.co/allenai/Olmo-3-7B-Instruct)\n",
    "  - [Olmo-3-7B-Think](https://huggingface.co/allenai/Olmo-3-7B-Think)\n",
    "- Note that there are also 32B versions, which are not listed above for brevity; you can find a complete list [here](https://huggingface.co/collections/allenai/olmo-3-post-training)\n",
    "- Below is a side-by-side comparison with Qwen3 8B as a reference model; if you are interested in the Qwen3 0.6B standalone notebook, you can find it [here](../11_qwen3)\n",
    "<br>\n",
    "\n",
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/olmo3/olmo3.webp\">\n",
    "  \n",
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/olmo3/olmo3-pipeline.webp\">\n",
    "  \n",
    "  \n",
    "- About the code:\n",
    "  - All code is my own, mapping the Olmo 3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7c201adb-747e-437b-9a62-442802941e01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
    "outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface_hub version: 0.35.0\n",
      "tokenizers version: 0.22.1\n",
      "torch version: 2.9.1+cu130\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "\n",
    "pkgs = [\n",
    "    \"huggingface_hub\",  # to download pretrained weights\n",
    "    \"tokenizers\",       # to implement the tokenizer\n",
    "    \"torch\",            # to implement the model\n",
    "]\n",
    "for p in pkgs:\n",
    "    print(f\"{p} version: {version(p)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07e96fbb-8e16-4f6d-835f-c6159321280b",
   "metadata": {},
   "source": [
    "- Note that there are three main model types, each available in a 7B and a 32B size:\n",
    "1. Base (`Olmo-3-1025-7B` and `Olmo-3-1125-32B`)\n",
    "2. Instruct (`Olmo-3-7B-Instruct` and `Olmo-3-32B-Instruct`)\n",
    "3. Reasoning (`Olmo-3-7B-Think` and `Olmo-3-32B-Think`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "70a90338-624a-4706-aa55-6b4358070194",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select which model to use\n",
    "\n",
    "# USE_MODEL = \"Olmo-3-1025-7B\"\n",
    "# USE_MODEL = \"Olmo-3-1125-32B\"\n",
    "USE_MODEL = \"Olmo-3-7B-Instruct\"\n",
    "# USE_MODEL = \"Olmo-3-32B-Instruct\"\n",
    "# USE_MODEL = \"Olmo-3-7B-Think\"\n",
    "# USE_MODEL = \"Olmo-3-32B-Think\"\n",
    "# USE_MODEL = \"Olmo-3-7B-RLZero-IF\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1899ab4b-e1c2-4215-b3d1-ed00d52e4576",
   "metadata": {},
   "source": [
    "- In addition to the checkpoints listed above, you can also use the intermediate checkpoints listed [here](https://huggingface.co/collections/allenai/olmo-3-post-training); since they share the same architecture, they are compatible with this notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "653410a6-dd2b-4eb2-a722-23d9782e726d",
   "metadata": {
    "id": "653410a6-dd2b-4eb2-a722-23d9782e726d"
   },
   "source": [
    "&nbsp;\n",
    "# 1. Architecture code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "82076c21-9331-4dcd-b017-42b046cf1a60",
   "metadata": {
    "id": "82076c21-9331-4dcd-b017-42b046cf1a60"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class FeedForward(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
    "        self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
    "        self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x_fc1 = self.fc1(x)\n",
    "        x_fc2 = self.fc2(x)\n",
    "        x = nn.functional.silu(x_fc1) * x_fc2\n",
    "        return self.fc3(x)"
   ]
  },
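  {
   "cell_type": "markdown",
   "id": "olmo3-sketch-swiglu-md",
   "metadata": {},
   "source": [
    "- The `FeedForward` module above is a gated MLP (often called SwiGLU): the SiLU-activated output of `fc1` gates the output of `fc2` element-wise, and `fc3` projects the result back to `emb_dim`\n",
    "- The minimal sketch below uses toy sizes to keep it fast (the helper names are hypothetical). It checks the output shape and counts the weights of one such module for the 7B sizes (`emb_dim=4096`, `hidden_dim=11008`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "olmo3-sketch-swiglu-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "def swiglu_output_shape(emb_dim=8, hidden_dim=16):\n",
    "    # Toy-sized gated MLP with the same structure as FeedForward above\n",
    "    fc1 = nn.Linear(emb_dim, hidden_dim, bias=False)  # gate branch (goes through SiLU)\n",
    "    fc2 = nn.Linear(emb_dim, hidden_dim, bias=False)  # linear branch\n",
    "    fc3 = nn.Linear(hidden_dim, emb_dim, bias=False)  # projection back to emb_dim\n",
    "\n",
    "    x = torch.randn(2, 5, emb_dim)  # (batch_size, num_tokens, emb_dim)\n",
    "    out = fc3(nn.functional.silu(fc1(x)) * fc2(x))\n",
    "    return out.shape\n",
    "\n",
    "\n",
    "def ffn_param_count(emb_dim, hidden_dim):\n",
    "    # Three bias-free weight matrices: fc1 and fc2 (emb_dim x hidden_dim), fc3 (hidden_dim x emb_dim)\n",
    "    return 3 * emb_dim * hidden_dim\n",
    "\n",
    "\n",
    "print(swiglu_output_shape())  # torch.Size([2, 5, 8])\n",
    "print(f\"{ffn_param_count(4096, 11008):,}\")  # 135,266,304 weights per layer for the 7B sizes"
   ]
  },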
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "56715760-37e1-433e-89da-04864c139a9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RMSNorm(nn.Module):\n",
    "    def __init__(self, emb_dim, eps=1e-6):\n",
    "        super().__init__()\n",
    "        self.eps = eps\n",
    "        self.weight = nn.Parameter(torch.ones(emb_dim))\n",
    "\n",
    "    def forward(self, x):\n",
    "        input_dtype = x.dtype\n",
    "        x_f = x.float()\n",
    "        var = x_f.pow(2).mean(dim=-1, keepdim=True)\n",
    "        x_norm = x_f * torch.rsqrt(var + self.eps)\n",
    "        return (self.weight * x_norm).to(input_dtype)"
   ]
  },
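  {
   "cell_type": "markdown",
   "id": "olmo3-sketch-rmsnorm-md",
   "metadata": {},
   "source": [
    "- `RMSNorm` divides each vector by its root mean square, `x / sqrt(mean(x**2) + eps)`, and then multiplies the result by a learned per-feature `weight`. Unlike LayerNorm, it does not subtract the mean and has no bias\n",
    "- As a quick sanity check, the sketch below compares this formula against PyTorch's built-in `torch.nn.RMSNorm` (assumption: PyTorch 2.4 or newer, where that module exists). The helper name `manual_rms_norm` is hypothetical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "olmo3-sketch-rmsnorm-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def manual_rms_norm(x, weight, eps=1e-6):\n",
    "    # Same computation as RMSNorm.forward above (float32 input, so no dtype casts are needed)\n",
    "    var = x.pow(2).mean(dim=-1, keepdim=True)\n",
    "    return weight * (x * torch.rsqrt(var + eps))\n",
    "\n",
    "\n",
    "gen = torch.Generator().manual_seed(123)\n",
    "x_demo = torch.randn(3, 8, generator=gen)\n",
    "ours = manual_rms_norm(x_demo, torch.ones(8))\n",
    "\n",
    "# Assumption: torch.nn.RMSNorm is available (PyTorch 2.4+); its weight is initialized to ones\n",
    "torch_rms_norm = torch.nn.RMSNorm(8, eps=1e-6)\n",
    "ref = torch_rms_norm(x_demo)\n",
    "\n",
    "print(torch.allclose(ours, ref, atol=1e-6))  # True\n",
    "# With unit weights, each output row has a root mean square of (almost exactly) 1\n",
    "print(ours.pow(2).mean(dim=-1).sqrt())"
   ]
  },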
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4b9a346f-5826-4083-9162-abd56afc03f0",
   "metadata": {
    "id": "4b9a346f-5826-4083-9162-abd56afc03f0"
   },
   "outputs": [],
   "source": [
    "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type=\"default\", rope_factor=1.0, rope_orig_max=8192, dtype=torch.float32):\n",
    "    assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
    "\n",
    "    # Compute the inverse frequencies\n",
    "    inv_freq = 1.0 / (\n",
    "        theta_base ** (\n",
    "            torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()\n",
    "            / head_dim\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # Generate position indices\n",
    "    positions = torch.arange(context_length, dtype=dtype)\n",
    "\n",
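    "    # Optional YaRN-style scaling (simplified here to linear position interpolation;\n",
    "    # the attention_factor applied to cos and sin below is YaRN's attention scaling)\n",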
    "    if rope_type == \"yarn\":\n",
    "        positions = positions / rope_factor\n",
    "        positions = torch.clamp(positions, max=rope_orig_max - 1)\n",
    "\n",
    "    # Compute the base angles (shape: [context_length, head_dim // 2])\n",
    "    angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)\n",
    "\n",
    "    # Expand to full head_dim (shape: [context_length, head_dim])\n",
    "    angles = torch.cat([angles, angles], dim=1)\n",
    "\n",
    "    # Precompute sine and cosine\n",
    "    cos = torch.cos(angles) * attention_factor\n",
    "    sin = torch.sin(angles) * attention_factor\n",
    "\n",
    "    return cos, sin\n",
    "\n",
    "\n",
    "def apply_rope(x, cos, sin, offset=0):\n",
    "    # x: (batch_size, num_heads, seq_len, head_dim)\n",
    "    batch_size, num_heads, seq_len, head_dim = x.shape\n",
    "    assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
    "\n",
    "    # Split x into first half and second half\n",
    "    x1 = x[..., : head_dim // 2]  # First half\n",
    "    x2 = x[..., head_dim // 2 :]  # Second half\n",
    "\n",
    "    # Adjust sin and cos shapes\n",
    "    cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)\n",
    "    sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)\n",
    "\n",
    "    # Apply the rotary transformation\n",
    "    rotated = torch.cat((-x2, x1), dim=-1)\n",
    "    x_rotated = (x * cos) + (rotated * sin)\n",
    "\n",
    "    # It's ok to use lower-precision after applying cos and sin rotation\n",
    "    return x_rotated.to(dtype=x.dtype)"
   ]
  },
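  {
   "cell_type": "markdown",
   "id": "olmo3-sketch-rope-md",
   "metadata": {},
   "source": [
    "- `apply_rope` uses the rotate-half convention: dimension `i` is paired with dimension `i + head_dim // 2`, and each pair is rotated by a position-dependent angle\n",
    "- Because every pair undergoes a pure rotation, RoPE preserves vector norms. The dot product between a rotated query and a rotated key also depends only on their relative distance\n",
    "- The sketch below checks both properties numerically. It uses a hypothetical helper `rope_rotate_single`, a toy `head_dim=8`, and no YaRN scaling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "olmo3-sketch-rope-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def rope_rotate_single(x, pos, theta_base=10_000.0):\n",
    "    # Rotate-half RoPE for one vector x of shape (head_dim,) at integer position pos\n",
    "    head_dim = x.shape[-1]\n",
    "    inv_freq = 1.0 / theta_base ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim)\n",
    "    angles = pos * inv_freq\n",
    "    angles = torch.cat([angles, angles])  # same layout as compute_rope_params\n",
    "    x1, x2 = x[: head_dim // 2], x[head_dim // 2 :]\n",
    "    rotated = torch.cat([-x2, x1])\n",
    "    return x * torch.cos(angles) + rotated * torch.sin(angles)\n",
    "\n",
    "\n",
    "gen = torch.Generator().manual_seed(0)\n",
    "q_demo = torch.randn(8, generator=gen, dtype=torch.float64)\n",
    "k_demo = torch.randn(8, generator=gen, dtype=torch.float64)\n",
    "\n",
    "# Same relative distance (2) at different absolute positions -> same attention score\n",
    "score_a = rope_rotate_single(q_demo, 5) @ rope_rotate_single(k_demo, 3)\n",
    "score_b = rope_rotate_single(q_demo, 12) @ rope_rotate_single(k_demo, 10)\n",
    "print(torch.allclose(score_a, score_b))  # True\n",
    "\n",
    "# Rotations preserve the vector norm\n",
    "print(torch.allclose(rope_rotate_single(q_demo, 7).norm(), q_demo.norm()))  # True"
   ]
  },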
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
   "metadata": {
    "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
   },
   "outputs": [],
   "source": [
    "class GroupedQueryAttention(nn.Module):\n",
    "    def __init__(self, d_in, num_heads, num_kv_groups, head_dim, attention_bias=False, dtype=None, sliding_window=None, attn_type=\"full_attention\"):\n",
    "        super().__init__()\n",
    "        assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n",
    "\n",
    "        self.num_heads = num_heads\n",
    "        self.num_kv_groups = num_kv_groups\n",
    "        self.group_size = num_heads // num_kv_groups\n",
    "\n",
    "        self.head_dim = head_dim\n",
    "        self.d_out = num_heads * head_dim\n",
    "        self.attn_type = attn_type\n",
    "        self.sliding_window = sliding_window if attn_type == \"sliding_attention\" else None\n",
    "\n",
    "        # Projections\n",
    "        self.W_query = nn.Linear(d_in, self.d_out, bias=attention_bias, dtype=dtype)\n",
    "        self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=attention_bias, dtype=dtype)\n",
    "        self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=attention_bias, dtype=dtype)\n",
    "        self.out_proj = nn.Linear(self.d_out, d_in, bias=attention_bias, dtype=dtype)\n",
    "\n",
    "        # Olmo3-style RMSNorm over the flattened projections\n",
    "        self.q_norm = RMSNorm(self.d_out)\n",
    "        self.k_norm = RMSNorm(num_kv_groups * head_dim)\n",
    "\n",
    "    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):\n",
    "        b, num_tokens, _ = x.shape\n",
    "\n",
    "        # Apply projections\n",
    "        queries = self.W_query(x)  # (b, num_tokens, num_heads * head_dim)\n",
    "        keys = self.W_key(x)       # (b, num_tokens, num_kv_groups * head_dim)\n",
    "        values = self.W_value(x)   # (b, num_tokens, num_kv_groups * head_dim)\n",
    "\n",
    "        # Normalize q and k\n",
    "        queries = self.q_norm(queries)\n",
    "        keys_new = self.k_norm(keys)\n",
    "\n",
    "        # Reshape to (b, heads, seq, head_dim)\n",
    "        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
    "        keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n",
    "        values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n",
    "\n",
    "        # Cache unrotated K/V\n",
    "        prev_len = 0\n",
    "        if cache is not None:\n",
    "            prev_k, prev_v = cache\n",
    "            if prev_k is not None:\n",
    "                prev_len = prev_k.size(2)\n",
    "                keys_cat_raw = torch.cat([prev_k, keys_new], dim=2)\n",
    "                values_cat_raw = torch.cat([prev_v, values_new], dim=2)\n",
    "            else:\n",
    "                keys_cat_raw = keys_new\n",
    "                values_cat_raw = values_new\n",
    "        else:\n",
    "            keys_cat_raw = keys_new\n",
    "            values_cat_raw = values_new\n",
    "\n",
    "        # Apply RoPE with offsets for cached tokens\n",
    "        queries = apply_rope(queries, cos, sin, offset=start_pos)\n",
    "        keys = apply_rope(keys_cat_raw, cos, sin, offset=start_pos - prev_len)\n",
    "\n",
    "        # Expand KV groups to full head count\n",
    "        if self.group_size > 1:\n",
    "            keys = keys.repeat_interleave(self.group_size, dim=1)\n",
    "            values = values_cat_raw.repeat_interleave(self.group_size, dim=1)\n",
    "        else:\n",
    "            values = values_cat_raw\n",
    "\n",
    "        # Scaling before the matmul seems to be a bit more stable for Olmo\n",
    "        scale = self.head_dim ** -0.5  # Python float\n",
    "        queries = queries * scale\n",
    "\n",
    "        # Update cache with unrotated K/V (identical to the concatenation above)\n",
    "        next_cache = (keys_cat_raw, values_cat_raw)\n",
    "\n",
    "        # Attention\n",
    "        attn_scores = queries @ keys.transpose(2, 3)\n",
    "        if mask is not None:\n",
    "            attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n",
    "\n",
    "        attn_weights = torch.softmax(attn_scores, dim=-1)\n",
    "        context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)\n",
    "        out = self.out_proj(context)\n",
    "\n",
    "        return out, next_cache"
   ]
  },
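  {
   "cell_type": "markdown",
   "id": "olmo3-sketch-gqa-md",
   "metadata": {},
   "source": [
    "- In `GroupedQueryAttention`, `repeat_interleave(self.group_size, dim=1)` expands the `num_kv_groups` key/value heads so that consecutive query heads share the same key/value head\n",
    "- The sketch below (hypothetical helper `kv_group_for_each_query_head`) makes that mapping explicit:\n",
    "  - with 8 query heads and 2 KV groups, query heads 0-3 use KV head 0 and heads 4-7 use KV head 1\n",
    "  - in the 32B config (40 query heads, 8 KV heads), each KV head is shared by 5 query heads\n",
    "  - the 7B config (32 and 32) has one KV head per query head, which is standard multi-head attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "olmo3-sketch-gqa-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def kv_group_for_each_query_head(num_heads, num_kv_groups):\n",
    "    # Label each KV head with its index, using the (b, heads, seq, head_dim) layout of the attention code\n",
    "    group_size = num_heads // num_kv_groups\n",
    "    kv_ids = torch.arange(num_kv_groups).view(1, num_kv_groups, 1, 1)\n",
    "    # Same expansion as in GroupedQueryAttention.forward\n",
    "    expanded = kv_ids.repeat_interleave(group_size, dim=1)\n",
    "    return expanded.flatten().tolist()\n",
    "\n",
    "\n",
    "print(kv_group_for_each_query_head(8, 2))        # [0, 0, 0, 0, 1, 1, 1, 1]\n",
    "print(kv_group_for_each_query_head(40, 8)[:10])  # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]\n",
    "print(kv_group_for_each_query_head(32, 32)[:4])  # [0, 1, 2, 3]"
   ]
  },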
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "13eb3430-0c06-4fe2-a005-217205eee21e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerBlock(nn.Module):\n",
    "    def __init__(self, cfg, attn_type):\n",
    "        super().__init__()\n",
    "        self.attn_type = attn_type\n",
    "        self.sliding_window = cfg[\"sliding_window\"]\n",
    "        self.att = GroupedQueryAttention(\n",
    "            d_in=cfg[\"emb_dim\"],\n",
    "            num_heads=cfg[\"n_heads\"],\n",
    "            num_kv_groups=cfg[\"n_kv_heads\"],\n",
    "            head_dim=cfg[\"head_dim\"],\n",
    "            attention_bias=cfg[\"attention_bias\"],\n",
    "            dtype=cfg[\"dtype\"],\n",
    "            sliding_window=cfg[\"sliding_window\"],\n",
    "            attn_type=attn_type,\n",
    "        )\n",
    "        self.ff = FeedForward(cfg)\n",
    "        self.post_attention_layernorm = RMSNorm(cfg[\"emb_dim\"], eps=cfg[\"rms_norm_eps\"])\n",
    "        self.post_feedforward_layernorm = RMSNorm(cfg[\"emb_dim\"], eps=cfg[\"rms_norm_eps\"])\n",
    "\n",
    "    def forward(self, x, mask_global, mask_local, cos, sin, start_pos=0, cache=None):\n",
    "        shortcut = x\n",
    "        if self.attn_type == \"sliding_attention\":\n",
    "            if cache is not None and isinstance(cache, tuple):\n",
    "                prev_k, _ = cache\n",
    "                prev_len = prev_k.size(2) if prev_k is not None else 0\n",
    "            else:\n",
    "                prev_len = 0\n",
    "            eff_kv_len = prev_len + x.size(1)\n",
    "            attn_mask = mask_local[..., -eff_kv_len:]\n",
    "        else:\n",
    "            attn_mask = mask_global\n",
    "\n",
    "        x_attn, next_cache = self.att(x, attn_mask, cos, sin, start_pos=start_pos, cache=cache)\n",
    "        if next_cache is not None and self.attn_type == \"sliding_attention\":\n",
    "            k, v = next_cache\n",
    "            if k.size(2) > self.sliding_window:\n",
    "                k = k[:, :, -self.sliding_window:, :]\n",
    "                v = v[:, :, -self.sliding_window:, :]\n",
    "            next_cache = (k, v)\n",
    "\n",
    "        x_attn = self.post_attention_layernorm(x_attn)\n",
    "        x = shortcut + x_attn\n",
    "\n",
    "        shortcut = x\n",
    "        x_ffn = self.ff(x)\n",
    "        x_ffn = self.post_feedforward_layernorm(x_ffn)\n",
    "        x = shortcut + x_ffn\n",
    "        return x, next_cache"
   ]
  },
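  {
   "cell_type": "markdown",
   "id": "olmo3-sketch-postnorm-md",
   "metadata": {},
   "source": [
    "- Note where the norms sit in `TransformerBlock`. The GPT model in the book and Llama-style models use pre-norm (`x + sublayer(norm(x))`), whereas the Olmo 3 block above normalizes the sublayer *output* before the residual addition (`x + norm(sublayer(x))`)\n",
    "- The minimal sketch below shows what this changes. It uses a parameter-free RMS norm and a toy sublayer that scales its input by 100 (all names are hypothetical)\n",
    "- With post-norm, the residual update has a root mean square of about 1, regardless of the sublayer's output scale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "olmo3-sketch-postnorm-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def rms_norm_no_weight(x, eps=1e-6):\n",
    "    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)\n",
    "\n",
    "\n",
    "def pre_norm_step(x, sublayer):\n",
    "    # GPT/Llama style: normalize the input, add the raw sublayer output\n",
    "    return x + sublayer(rms_norm_no_weight(x))\n",
    "\n",
    "\n",
    "def post_norm_step(x, sublayer):\n",
    "    # Olmo 3 style (as in TransformerBlock above): normalize the sublayer output, then add\n",
    "    return x + rms_norm_no_weight(sublayer(x))\n",
    "\n",
    "\n",
    "def loud_sublayer(x):\n",
    "    # Stand-in for an attention or feed-forward sublayer with a large output scale\n",
    "    return 100.0 * x\n",
    "\n",
    "\n",
    "def row_rms(x):\n",
    "    return x.pow(2).mean(dim=-1).sqrt()\n",
    "\n",
    "\n",
    "gen = torch.Generator().manual_seed(0)\n",
    "x_demo = torch.randn(4, 16, generator=gen)\n",
    "\n",
    "pre_update = pre_norm_step(x_demo, loud_sublayer) - x_demo\n",
    "post_update = post_norm_step(x_demo, loud_sublayer) - x_demo\n",
    "print(row_rms(pre_update))   # about 100 for every row\n",
    "print(row_rms(post_update))  # about 1 for every row"
   ]
  },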
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
   "metadata": {
    "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
   },
   "outputs": [],
   "source": [
    "class Olmo3Model(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        assert cfg[\"layer_types\"] is not None and len(cfg[\"layer_types\"]) == cfg[\"n_layers\"]\n",
    "\n",
    "        self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
    "        self.blocks = nn.ModuleList([TransformerBlock(cfg, attn_type) for attn_type in cfg[\"layer_types\"]])\n",
    "        self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=cfg[\"rms_norm_eps\"])\n",
    "        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
    "        self.cfg = cfg\n",
    "        self.current_pos = 0\n",
    "\n",
    "        cos, sin = compute_rope_params(\n",
    "            head_dim=cfg[\"head_dim\"],\n",
    "            context_length=cfg[\"context_length\"],\n",
    "            theta_base=cfg[\"rope_base\"],\n",
    "            attention_factor=cfg[\"rope_attention_factor\"],\n",
    "            rope_type=cfg[\"rope_type\"],\n",
    "            rope_factor=cfg[\"rope_factor\"],\n",
    "            rope_orig_max=cfg[\"rope_orig_max\"],\n",
    "            dtype=torch.float32,\n",
    "        )\n",
    "        self.register_buffer(\"cos\", cos, persistent=False)\n",
    "        self.register_buffer(\"sin\", sin, persistent=False)\n",
    "\n",
    "    def create_masks(self, cur_len, device, pos_start=0, pos_end=None):\n",
    "        if pos_end is None:\n",
    "            pos_end = cur_len\n",
    "        total_len = pos_end\n",
    "\n",
    "        ones = torch.ones((total_len, total_len), dtype=torch.bool, device=device)\n",
    "        # mask_global_full (future is masked: j > i)\n",
    "        #     j:  0 1 2 3 4 5 6 7\n",
    "        #  i\n",
    "        #     0:  0 1 1 1 1 1 1 1\n",
    "        #     1:  0 0 1 1 1 1 1 1\n",
    "        #     2:  0 0 0 1 1 1 1 1\n",
    "        #     3:  0 0 0 0 1 1 1 1\n",
    "        #     4:  0 0 0 0 0 1 1 1\n",
    "        #     5:  0 0 0 0 0 0 1 1\n",
    "        #     6:  0 0 0 0 0 0 0 1\n",
    "        #     7:  0 0 0 0 0 0 0 0\n",
    "        mask_global_full = torch.triu(ones, diagonal=1)\n",
    "\n",
    "        # far_past (too far back is masked: i - j >= sliding_window)\n",
    "        # where sliding_window = 4\n",
    "        #     j:  0 1 2 3 4 5 6 7\n",
    "        #  i\n",
    "        #     0:  0 0 0 0 0 0 0 0\n",
    "        #     1:  0 0 0 0 0 0 0 0\n",
    "        #     2:  0 0 0 0 0 0 0 0\n",
    "        #     3:  0 0 0 0 0 0 0 0\n",
    "        #     4:  1 0 0 0 0 0 0 0\n",
    "        #     5:  1 1 0 0 0 0 0 0\n",
    "        #     6:  1 1 1 0 0 0 0 0\n",
    "        #     7:  1 1 1 1 0 0 0 0\n",
    "        far_past_full = torch.triu(ones, diagonal=self.cfg[\"sliding_window\"]).T\n",
    "\n",
    "        # Local (sliding_window) = future OR far-past\n",
    "        # mask_local\n",
    "        #     j:  0 1 2 3 4 5 6 7\n",
    "        #  i\n",
    "        #     0:  0 1 1 1 1 1 1 1\n",
    "        #     1:  0 0 1 1 1 1 1 1\n",
    "        #     2:  0 0 0 1 1 1 1 1\n",
    "        #     3:  0 0 0 0 1 1 1 1\n",
    "        #     4:  1 0 0 0 0 1 1 1\n",
    "        #     5:  1 1 0 0 0 0 1 1\n",
    "        #     6:  1 1 1 0 0 0 0 1\n",
    "        #     7:  1 1 1 1 0 0 0 0\n",
    "        mask_local_full = mask_global_full | far_past_full\n",
    "\n",
    "        row_slice = slice(pos_start, pos_end)\n",
    "        mask_global = mask_global_full[row_slice, :pos_end][None, None, :, :]\n",
    "        mask_local = mask_local_full[row_slice, :pos_end][None, None, :, :]\n",
    "        return mask_global, mask_local\n",
    "\n",
    "    def forward(self, input_ids, cache=None):\n",
    "        b, seq_len = input_ids.shape\n",
    "        x = self.tok_emb(input_ids)\n",
    "\n",
    "        if cache is not None:\n",
    "            pos_start = self.current_pos\n",
    "            pos_end = pos_start + seq_len\n",
    "            self.current_pos = pos_end\n",
    "            mask_global, mask_local = self.create_masks(\n",
    "                cur_len=seq_len, device=x.device, pos_start=pos_start, pos_end=pos_end\n",
    "            )\n",
    "        else:\n",
    "            pos_start = 0\n",
    "            mask_global, mask_local = self.create_masks(\n",
    "                cur_len=seq_len, device=x.device, pos_start=0, pos_end=seq_len\n",
    "            )\n",
    "\n",
    "        cos = self.cos\n",
    "        sin = self.sin\n",
    "\n",
    "        for i, block in enumerate(self.blocks):\n",
    "            blk_cache = cache.get(i) if cache is not None else None\n",
    "            x, new_blk_cache = block(\n",
    "                x,\n",
    "                mask_global=mask_global,\n",
    "                mask_local=mask_local,\n",
    "                cos=cos,\n",
    "                sin=sin,\n",
    "                start_pos=pos_start,\n",
    "                cache=blk_cache,\n",
    "            )\n",
    "\n",
    "            if cache is not None:\n",
    "                cache.update(i, new_blk_cache)\n",
    "\n",
    "        x = self.final_norm(x)\n",
    "        logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
    "        return logits\n",
    "\n",
    "    def reset_kv_cache(self):\n",
    "        self.current_pos = 0"
   ]
  },
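  {
   "cell_type": "markdown",
   "id": "olmo3-sketch-masks-md",
   "metadata": {},
   "source": [
    "- To see how `create_masks` and the sliding-window slicing in `TransformerBlock` fit together, the sketch below rebuilds both masks for the 8-token, `sliding_window=4` example from the comments. `True` marks masked-out positions, and the helper name is hypothetical\n",
    "- It then mimics cached decoding of the token at position 7. The mask row for that position, restricted to the last 4 key columns (the keys kept in a sliding-window cache), masks nothing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "olmo3-sketch-masks-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def build_demo_masks(total_len, sliding_window):\n",
    "    # Same construction as Olmo3Model.create_masks (without the row slicing)\n",
    "    ones = torch.ones((total_len, total_len), dtype=torch.bool)\n",
    "    mask_global = torch.triu(ones, diagonal=1)              # j > i (future)\n",
    "    far_past = torch.triu(ones, diagonal=sliding_window).T  # i - j >= sliding_window\n",
    "    return mask_global, mask_global | far_past\n",
    "\n",
    "\n",
    "demo_global, demo_local = build_demo_masks(total_len=8, sliding_window=4)\n",
    "\n",
    "# Number of visible (unmasked) keys per query position\n",
    "print((~demo_global).sum(dim=1).tolist())  # [1, 2, 3, 4, 5, 6, 7, 8]\n",
    "print((~demo_local).sum(dim=1).tolist())   # [1, 2, 3, 4, 4, 4, 4, 4]\n",
    "\n",
    "# Cached decoding of position 7: rows pos_start:pos_end, then the last eff_kv_len columns\n",
    "decode_mask = demo_local[7:8, :8][..., -4:]\n",
    "print(decode_mask.any().item())  # False (all 4 cached keys are visible)"
   ]
  },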
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4f5271e8-ff28-4aaa-bbb2-f73582e6d228",
   "metadata": {},
   "outputs": [],
   "source": [
    "class KVCache:\n",
    "    def __init__(self, n_layers):\n",
    "        self.cache = [None] * n_layers\n",
    "\n",
    "    def get(self, layer_idx):\n",
    "        return self.cache[layer_idx]\n",
    "\n",
    "    def update(self, layer_idx, value):\n",
    "        self.cache[layer_idx] = value\n",
    "\n",
    "    def get_all(self):\n",
    "        return self.cache\n",
    "\n",
    "    def reset(self):\n",
    "        for i in range(len(self.cache)):\n",
    "            self.cache[i] = None"
   ]
  },
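  {
   "cell_type": "markdown",
   "id": "olmo3-sketch-kvcache-md",
   "metadata": {},
   "source": [
    "- The `KVCache` class above only stores one `(keys, values)` tuple per layer. During generation, each new token's keys and values are appended to it (see `GroupedQueryAttention.forward`), so earlier tokens never have to be re-projected\n",
    "- This works because causal attention for the newest token depends only on its own query and on the keys/values of the tokens up to and including it\n",
    "- The toy sketch below checks that token-by-token decoding with cached keys/values reproduces the full causal attention output. It uses a single head in float64, with no RoPE or norms, and all names are hypothetical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "olmo3-sketch-kvcache-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def causal_attention(q, k, v):\n",
    "    # q: (t_q, d) holds the last t_q positions; k, v: (t_k, d) hold all positions so far\n",
    "    t_q, t_k = q.size(0), k.size(0)\n",
    "    scores = q @ k.T / k.size(-1) ** 0.5\n",
    "    q_pos = torch.arange(t_k - t_q, t_k).unsqueeze(1)\n",
    "    k_pos = torch.arange(t_k).unsqueeze(0)\n",
    "    scores = scores.masked_fill(k_pos > q_pos, float(\"-inf\"))\n",
    "    return torch.softmax(scores, dim=-1) @ v\n",
    "\n",
    "\n",
    "def kv_cache_matches_full_attention(num_tokens=6, d=4):\n",
    "    gen = torch.Generator().manual_seed(0)\n",
    "    W_q, W_k, W_v = (torch.randn(d, d, generator=gen, dtype=torch.float64) for _ in range(3))\n",
    "    x = torch.randn(num_tokens, d, generator=gen, dtype=torch.float64)\n",
    "\n",
    "    # Full causal attention over the whole sequence at once\n",
    "    full = causal_attention(x @ W_q, x @ W_k, x @ W_v)\n",
    "\n",
    "    # Token-by-token decoding that appends each new key/value to a cache\n",
    "    cached_k, cached_v, outputs = [], [], []\n",
    "    for t in range(num_tokens):\n",
    "        x_t = x[t : t + 1]\n",
    "        cached_k.append(x_t @ W_k)\n",
    "        cached_v.append(x_t @ W_v)\n",
    "        outputs.append(causal_attention(x_t @ W_q, torch.cat(cached_k), torch.cat(cached_v)))\n",
    "    incremental = torch.cat(outputs)\n",
    "\n",
    "    return torch.allclose(full, incremental)\n",
    "\n",
    "\n",
    "print(kv_cache_matches_full_attention())  # True"
   ]
  },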
  {
   "cell_type": "markdown",
   "id": "be2d201f-74ad-4d63-ab9c-601b00674a48",
   "metadata": {
    "id": "be2d201f-74ad-4d63-ab9c-601b00674a48"
   },
   "source": [
    "&nbsp;\n",
    "# 2. Initialize model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "caa142fa-b375-4e78-b392-2072ced666f3",
   "metadata": {
    "id": "caa142fa-b375-4e78-b392-2072ced666f3"
   },
   "outputs": [],
   "source": [
    "OLMO3_CONFIG_7B = {\n",
    "    \"vocab_size\": 100_278,\n",
    "    \"context_length\": 65_536,\n",
    "    \"emb_dim\": 4_096,\n",
    "    \"n_heads\": 32,\n",
    "    \"n_layers\": 32,\n",
    "    \"hidden_dim\": 11_008,\n",
    "    \"head_dim\": 128,\n",
    "    \"n_kv_heads\": 32,\n",
    "    \"attention_bias\": False,\n",
    "    \"attention_dropout\": 0.0,\n",
    "    \"sliding_window\": 4_096,\n",
    "    # 3 sliding-window layers followed by 1 full-attention layer, repeated 8 times (32 layers)\n",
    "    \"layer_types\": [\"sliding_attention\", \"sliding_attention\", \"sliding_attention\", \"full_attention\"] * 8,\n",
    "    \"rope_base\": 500_000.0,\n",
    "    \"rope_attention_factor\": 1.2079441541679836,\n",
    "    \"rope_type\": \"yarn\",\n",
    "    \"rope_factor\": 8.0,\n",
    "    \"rope_orig_max\": 8_192,\n",
    "    \"rms_norm_eps\": 1e-6,\n",
    "    \"dtype\": torch.bfloat16,\n",
    "    \"eos_token_id\": 100_257,\n",
    "    \"pad_token_id\": 100_277,\n",
    "}\n",
    "\n",
    "OLMO3_CONFIG_32B = {\n",
    "    \"vocab_size\": 100_278,\n",
    "    \"context_length\": 65_536,\n",
    "    \"emb_dim\": 5_120,\n",
    "    \"n_heads\": 40,\n",
    "    \"n_layers\": 64,\n",
    "    \"hidden_dim\": 27_648,\n",
    "    \"head_dim\": 128,\n",
    "    \"n_kv_heads\": 8,\n",
    "    \"attention_bias\": False,\n",
    "    \"attention_dropout\": 0.0,\n",
    "    \"sliding_window\": 4_096,\n",
    "    # 3 sliding-window layers followed by 1 full-attention layer, repeated 16 times (64 layers)\n",
    "    \"layer_types\": [\"sliding_attention\", \"sliding_attention\", \"sliding_attention\", \"full_attention\"] * 16,\n",
    "    \"rope_base\": 500_000.0,\n",
    "    \"rope_attention_factor\": 1.2079441541679836,\n",
    "    \"rope_type\": \"yarn\",\n",
    "    \"rope_factor\": 8.0,\n",
    "    \"rope_orig_max\": 8_192,\n",
    "    \"rms_norm_eps\": 1e-6,\n",
    "    \"dtype\": torch.bfloat16,\n",
    "    \"eos_token_id\": 100_257,\n",
    "    \"pad_token_id\": 100_277,\n",
    "}\n",
    "\n",
    "OLMO3_CONFIG = OLMO3_CONFIG_32B if \"32B\" in USE_MODEL else OLMO3_CONFIG_7B"
   ]
  },
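  {
   "cell_type": "markdown",
   "id": "olmo3-sketch-config-md",
   "metadata": {},
   "source": [
    "- Two quantities implied by the configs above can be derived by hand (a sketch; the helper names are hypothetical):\n",
    "  - `rope_attention_factor` matches YaRN's attention-scaling formula `0.1 * ln(rope_factor) + 1` for `rope_factor=8`\n",
    "  - the KV-cache size: because of the cache trimming in `TransformerBlock`, sliding-window layers keep at most `sliding_window` tokens, while full-attention layers keep the whole context\n",
    "- At the full 65,536-token context in bfloat16 (2 bytes per value), the KV cache comes to about 9.5 GiB for the 7B config, compared with 32 GiB if every layer kept every token\n",
    "- The 32B config needs only about 4.75 GiB, since it also uses GQA (8 instead of 40 KV heads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "olmo3-sketch-config-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def yarn_attention_factor(rope_factor):\n",
    "    # YaRN attention scaling: 0.1 * ln(s) + 1\n",
    "    return 0.1 * math.log(rope_factor) + 1.0\n",
    "\n",
    "\n",
    "def kv_cache_bytes(layer_types, n_kv_heads, head_dim, seq_len, sliding_window=4_096, bytes_per_value=2):\n",
    "    # Keys + values for one token in one layer\n",
    "    per_token_per_layer = 2 * n_kv_heads * head_dim * bytes_per_value\n",
    "    cached_tokens = sum(\n",
    "        seq_len if t == \"full_attention\" else min(seq_len, sliding_window)\n",
    "        for t in layer_types\n",
    "    )\n",
    "    return per_token_per_layer * cached_tokens\n",
    "\n",
    "\n",
    "demo_pattern = [\"sliding_attention\"] * 3 + [\"full_attention\"]  # repeating 4-layer pattern in the configs\n",
    "GIB = 1024 ** 3\n",
    "\n",
    "print(yarn_attention_factor(8.0))  # approximately 1.2079441541679836\n",
    "print(kv_cache_bytes(demo_pattern * 8, n_kv_heads=32, head_dim=128, seq_len=65_536) / GIB)  # 9.5 (7B)\n",
    "print(kv_cache_bytes(demo_pattern * 8, n_kv_heads=32, head_dim=128, seq_len=65_536,\n",
    "                     sliding_window=65_536) / GIB)  # 32.0 (7B if every layer kept every token)\n",
    "print(kv_cache_bytes(demo_pattern * 16, n_kv_heads=8, head_dim=128, seq_len=65_536) / GIB)  # 4.75 (32B)"
   ]
  },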
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
   "metadata": {
    "id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(123)\n",
    "model = Olmo3Model(OLMO3_CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "eaf86265-4e9d-4024-9ed0-99076944e304",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Olmo3Model(\n",
       "  (tok_emb): Embedding(100278, 4096)\n",
       "  (blocks): ModuleList(\n",
       "    (0-31): 32 x TransformerBlock(\n",
       "      (att): GroupedQueryAttention(\n",
       "        (W_query): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "        (W_key): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "        (W_value): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "        (out_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "        (q_norm): RMSNorm()\n",
       "        (k_norm): RMSNorm()\n",
       "      )\n",
       "      (ff): FeedForward(\n",
       "        (fc1): Linear(in_features=4096, out_features=11008, bias=False)\n",
       "        (fc2): Linear(in_features=4096, out_features=11008, bias=False)\n",
       "        (fc3): Linear(in_features=11008, out_features=4096, bias=False)\n",
       "      )\n",
       "      (post_attention_layernorm): RMSNorm()\n",
       "      (post_feedforward_layernorm): RMSNorm()\n",
       "    )\n",
       "  )\n",
       "  (final_norm): RMSNorm()\n",
       "  (out_head): Linear(in_features=4096, out_features=100278, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
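  {
   "cell_type": "markdown",
   "id": "olmo3-sketch-params-md",
   "metadata": {},
   "source": [
    "- The module tree printed above also determines the parameter count. The sketch below (hypothetical helper `olmo3_param_count`) adds up the shapes shown there for the 7B configuration as coded in this notebook, which has separate (untied) `tok_emb` and `out_head` matrices\n",
    "- This arithmetic should agree with `sum(p.numel() for p in model.parameters())` for the 7B model. At 2 bytes per bfloat16 weight, it implies roughly 14.6 GB for the weights alone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "olmo3-sketch-params-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def olmo3_param_count(vocab_size, emb_dim, n_heads, n_kv_heads, head_dim, hidden_dim, n_layers):\n",
    "    attn = emb_dim * (n_heads * head_dim)          # W_query\n",
    "    attn += 2 * emb_dim * (n_kv_heads * head_dim)  # W_key, W_value\n",
    "    attn += (n_heads * head_dim) * emb_dim         # out_proj\n",
    "    norms = n_heads * head_dim + n_kv_heads * head_dim + 2 * emb_dim  # q_norm, k_norm, 2 post-norms\n",
    "    ffn = 3 * emb_dim * hidden_dim                 # fc1, fc2, fc3\n",
    "    per_block = attn + norms + ffn\n",
    "    embeddings = 2 * vocab_size * emb_dim          # tok_emb + out_head (untied)\n",
    "    return n_layers * per_block + embeddings + emb_dim  # + final_norm\n",
    "\n",
    "\n",
    "n_params_7b = olmo3_param_count(\n",
    "    vocab_size=100_278, emb_dim=4_096, n_heads=32, n_kv_heads=32,\n",
    "    head_dim=128, hidden_dim=11_008, n_layers=32,\n",
    ")\n",
    "print(f\"{n_params_7b:,}\")  # 7,298,011,136\n",
    "print(f\"{n_params_7b * 2 / 1e9:.1f} GB in bfloat16\")  # 14.6 GB in bfloat16"
   ]
  },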
  {
   "cell_type": "markdown",
   "id": "90aca91d-4bee-45ce-993a-4ec5393abe2b",
   "metadata": {},
   "source": [
    "- A quick check that the forward pass works before continuing:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.3594, -0.6289, -0.2754,  ...,  1.1016,  0.4219,  0.0381],\n",
       "         [ 1.1719,  0.0283,  0.6055,  ...,  0.4863, -0.1953,  0.2246],\n",
       "         [ 0.4902, -0.0425,  0.6758,  ...,  0.3730, -0.5781, -0.1670]]],\n",
       "       dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(torch.tensor([1, 2, 3]).unsqueeze(0))"
   ]
  },
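  {
   "cell_type": "markdown",
   "id": "4f2a9c1e-8b3d-4e7a-a5c6-1d2e3f4a5b63",
   "metadata": {},
   "source": [
    "- The forward pass returns one vector of logits per input token, so the output has shape `(batch_size, num_tokens, vocab_size)`, which is `(1, 3, 100278)` for the 3-token input above\n",
    "- Only the last position is needed to pick the next token\n",
    "- The sketch below illustrates this on a random stand-in tensor with the same shape; its values are made up and are not model outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f2a9c1e-8b3d-4e7a-a5c6-1d2e3f4a5b64",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Random stand-in with the same shape as the model output above\n",
    "demo_logits = torch.randn(1, 3, 100278)  # (batch_size, num_tokens, vocab_size)\n",
    "\n",
    "# Scores for the token that would follow the last input position\n",
    "last_logits = demo_logits[:, -1, :]  # shape: (1, 100278)\n",
    "\n",
    "# Greedy choice of the next token ID, as in the generation loop further below\n",
    "next_id = torch.argmax(last_logits, dim=-1, keepdim=True)  # shape: (1, 1)\n",
    "print(demo_logits.shape, last_logits.shape, next_id.shape)"
   ]
  },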
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
   "metadata": {
    "id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
   },
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device(\"mps\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "model.to(device);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c172f89f-d301-439f-b809-46169e5f5945",
   "metadata": {
    "id": "c172f89f-d301-439f-b809-46169e5f5945"
   },
   "source": [
    "&nbsp;\n",
    "# 3. Load pretrained weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "75166128-5899-4995-9b88-9672e135650e",
   "metadata": {
    "id": "75166128-5899-4995-9b88-9672e135650e"
   },
   "outputs": [],
   "source": [
    "def load_weights_into_olmo(model, param_config, params):\n",
    "    def assign(left, right, tensor_name=\"unknown\"):\n",
    "        if left.shape != right.shape:\n",
    "            raise ValueError(\n",
    "                f\"Shape mismatch in tensor '{tensor_name}'. \"\n",
    "                f\"Left: {left.shape}, Right: {right.shape}\"\n",
    "            )\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            if isinstance(right, torch.Tensor):\n",
    "                left.copy_(right)\n",
    "            else:\n",
    "                left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
    "        \n",
    "        return left\n",
    "\n",
    "    # Token embedding\n",
    "    if \"model.embed_tokens.weight\" in params:\n",
    "        model.tok_emb.weight = assign(\n",
    "            model.tok_emb.weight,\n",
    "            params[\"model.embed_tokens.weight\"],\n",
    "            \"model.embed_tokens.weight\",\n",
    "        )\n",
    "\n",
    "    for l in range(param_config[\"n_layers\"]):\n",
    "        block = model.blocks[l]\n",
    "        att = block.att\n",
    "\n",
    "        # Q, K, V projections\n",
    "        att.W_query.weight = assign(\n",
    "            att.W_query.weight,\n",
    "            params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n",
    "            f\"model.layers.{l}.self_attn.q_proj.weight\",\n",
    "        )\n",
    "        att.W_key.weight = assign(\n",
    "            att.W_key.weight,\n",
    "            params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n",
    "            f\"model.layers.{l}.self_attn.k_proj.weight\",\n",
    "        )\n",
    "        att.W_value.weight = assign(\n",
    "            att.W_value.weight,\n",
    "            params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n",
    "            f\"model.layers.{l}.self_attn.v_proj.weight\",\n",
    "        )\n",
    "\n",
    "        # Output projection\n",
    "        att.out_proj.weight = assign(\n",
    "            att.out_proj.weight,\n",
    "            params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n",
    "            f\"model.layers.{l}.self_attn.o_proj.weight\",\n",
    "        )\n",
    "\n",
    "        # QK norms\n",
    "        att.q_norm.weight = assign(\n",
    "            att.q_norm.weight,\n",
    "            params[f\"model.layers.{l}.self_attn.q_norm.weight\"],\n",
    "            f\"model.layers.{l}.self_attn.q_norm.weight\",\n",
    "        )\n",
    "        att.k_norm.weight = assign(\n",
    "            att.k_norm.weight,\n",
    "            params[f\"model.layers.{l}.self_attn.k_norm.weight\"],\n",
    "            f\"model.layers.{l}.self_attn.k_norm.weight\",\n",
    "        )\n",
    "\n",
    "        # Feedforward weights\n",
    "        block.ff.fc1.weight = assign(\n",
    "            block.ff.fc1.weight,\n",
    "            params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n",
    "            f\"model.layers.{l}.mlp.gate_proj.weight\",\n",
    "        )\n",
    "        block.ff.fc2.weight = assign(\n",
    "            block.ff.fc2.weight,\n",
    "            params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n",
    "            f\"model.layers.{l}.mlp.up_proj.weight\",\n",
    "        )\n",
    "        block.ff.fc3.weight = assign(\n",
    "            block.ff.fc3.weight,\n",
    "            params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n",
    "            f\"model.layers.{l}.mlp.down_proj.weight\",\n",
    "        )\n",
    "\n",
    "        # Post-attention and post-feedforward norms\n",
    "        block.post_attention_layernorm.weight = assign(\n",
    "            block.post_attention_layernorm.weight,\n",
    "            params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n",
    "            f\"model.layers.{l}.post_attention_layernorm.weight\",\n",
    "        )\n",
    "        block.post_feedforward_layernorm.weight = assign(\n",
    "            block.post_feedforward_layernorm.weight,\n",
    "            params[f\"model.layers.{l}.post_feedforward_layernorm.weight\"],\n",
    "            f\"model.layers.{l}.post_feedforward_layernorm.weight\",\n",
    "        )\n",
    "\n",
    "    # Final normalization and output head\n",
    "    if \"model.norm.weight\" in params:\n",
    "        model.final_norm.weight = assign(\n",
    "            model.final_norm.weight,\n",
    "            params[\"model.norm.weight\"],\n",
    "            \"model.norm.weight\",\n",
    "        )\n",
    "\n",
    "    if \"lm_head.weight\" in params:\n",
    "        model.out_head.weight = assign(\n",
    "            model.out_head.weight,\n",
    "            params[\"lm_head.weight\"],\n",
    "            \"lm_head.weight\",\n",
    "        )\n",
    "    else:\n",
    "        model.out_head.weight = model.tok_emb.weight\n",
    "        print(\"Model uses weight tying.\")"
   ]
  },
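  {
   "cell_type": "markdown",
   "id": "4f2a9c1e-8b3d-4e7a-a5c6-1d2e3f4a5b65",
   "metadata": {},
   "source": [
    "- `load_weights_into_olmo` copies each checkpoint tensor in place into the matching parameter and raises a `ValueError` if the shapes differ\n",
    "- If the checkpoint contains no `lm_head.weight`, the output head reuses the token embedding matrix (weight tying)\n",
    "- The toy sketch below uses small hypothetical layers, not the Olmo model, to show that tying makes both modules share a single `Parameter`, so an update through one is visible through the other"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f2a9c1e-8b3d-4e7a-a5c6-1d2e3f4a5b66",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Hypothetical toy sizes (vocabulary of 10, embedding dimension of 4)\n",
    "toy_emb = nn.Embedding(10, 4)            # weight shape: (10, 4)\n",
    "toy_head = nn.Linear(4, 10, bias=False)  # weight shape: (out_features, in_features) = (10, 4)\n",
    "\n",
    "# Weight tying: the output head now references the embedding Parameter\n",
    "toy_head.weight = toy_emb.weight\n",
    "print(toy_head.weight is toy_emb.weight)  # True\n",
    "\n",
    "# An in-place change made through the embedding is visible in the head\n",
    "with torch.no_grad():\n",
    "    toy_emb.weight[0].zero_()\n",
    "print(bool(torch.all(toy_head.weight[0] == 0)))  # True"
   ]
  },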
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 17,
     "referenced_widgets": [
      "9881b6995c3f49dc89e6992fd9ab660b",
      "17a3174e65c54476b2e0d1faf8f011ca",
      "1bbf2e62c0754d1593beb4105a7f1ac1",
      "b82112e1dec645d98aa1c1ba64abcb61",
      "271e2bd6a35e4a8b92de8697f7c0be5f",
      "90a79523187446dfa692723b2e5833a7",
      "431ffb83b8c14bf182f0430e07ea6154",
      "a8f1b72a33dd4b548de23fbd95e0da18",
      "25cc36132d384189acfbecc59483134b",
      "bfd06423ad544218968648016e731a46",
      "d029630b63ff44cf807ade428d2eb421"
     ]
    },
    "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
    "outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0fcdf72bf5b646d39bf4ed84faeb3302",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 14 files:   0%|          | 0/14 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import json\n",
    "import os\n",
    "from pathlib import Path\n",
    "from safetensors.torch import load_file\n",
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "repo_id = f\"allenai/{USE_MODEL}\"\n",
    "local_dir = Path(repo_id).parts[-1]\n",
    "\n",
    "repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir)\n",
    "index_path = os.path.join(repo_dir, \"model.safetensors.index.json\")\n",
    "with open(index_path, \"r\") as f:\n",
    "    index = json.load(f)\n",
    "\n",
    "weights_dict = {}\n",
    "for filename in sorted(set(index[\"weight_map\"].values())):\n",
    "    shard_path = os.path.join(repo_dir, filename)\n",
    "    shard = load_file(shard_path)\n",
    "    weights_dict.update(shard)\n",
    "\n",
    "load_weights_into_olmo(model, OLMO3_CONFIG, weights_dict)\n",
    "model.to(device)\n",
    "del weights_dict"
   ]
  },
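  {
   "cell_type": "markdown",
   "id": "4f2a9c1e-8b3d-4e7a-a5c6-1d2e3f4a5b67",
   "metadata": {},
   "source": [
    "- The checkpoint is split into several `.safetensors` shards\n",
    "- `model.safetensors.index.json` holds a `weight_map` that maps each tensor name to the shard file storing it, which is why the cell above loops over the unique file names in `index[\"weight_map\"]`\n",
    "- The self-contained sketch below recreates this layout with two tiny hypothetical shards in a temporary directory and reads them back the same way\n",
    "- Only the `weight_map` part of the index is written; real index files also carry a `metadata` entry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f2a9c1e-8b3d-4e7a-a5c6-1d2e3f4a5b68",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from safetensors.torch import load_file, save_file\n",
    "\n",
    "# Hypothetical two-shard checkpoint with made-up tensor names\n",
    "toy_shards = {\n",
    "    \"toy-00001-of-00002.safetensors\": {\"a.weight\": torch.ones(2, 2)},\n",
    "    \"toy-00002-of-00002.safetensors\": {\"b.weight\": torch.zeros(3)},\n",
    "}\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    tmp_dir = Path(tmp)\n",
    "\n",
    "    # Write each shard and record which file stores which tensor\n",
    "    toy_weight_map = {}\n",
    "    for shard_name, tensors in toy_shards.items():\n",
    "        save_file(tensors, str(tmp_dir / shard_name))\n",
    "        for tensor_name in tensors:\n",
    "            toy_weight_map[tensor_name] = shard_name\n",
    "    with open(tmp_dir / \"model.safetensors.index.json\", \"w\") as fh:\n",
    "        json.dump({\"weight_map\": toy_weight_map}, fh)\n",
    "\n",
    "    # Same loading pattern as above: read the index, then load each shard once\n",
    "    with open(tmp_dir / \"model.safetensors.index.json\") as fh:\n",
    "        toy_index = json.load(fh)\n",
    "    toy_weights = {}\n",
    "    for shard_name in sorted(set(toy_index[\"weight_map\"].values())):\n",
    "        toy_weights.update(load_file(str(tmp_dir / shard_name)))\n",
    "\n",
    "print(sorted(toy_weights))  # ['a.weight', 'b.weight']"
   ]
  },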
  {
   "cell_type": "markdown",
   "id": "6b345491-3510-4397-92d3-cd0a3fa3deee",
   "metadata": {},
   "source": [
    "&nbsp;\n",
    "# 4. Load tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b68ab489-48e5-471e-a814-56cda2d60f81",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tokenizers import Tokenizer\n",
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "\n",
    "class OlmoTokenizer:\n",
    "    def __init__(self, tokenizer_file_path, eos_token_id, pad_token_id):\n",
    "        tok_file = Path(tokenizer_file_path)\n",
    "        self._tok = Tokenizer.from_file(str(tok_file))\n",
    "        eos_from_tok = (\n",
    "            self._tok.token_to_id(\"<|endoftext|>\")\n",
    "            or self._tok.token_to_id(\"<end_of_turn>\")\n",
    "        )\n",
    "        self.eos_token_id = eos_from_tok if eos_from_tok is not None else eos_token_id\n",
    "        pad_from_tok = (\n",
    "            self._tok.token_to_id(\"<|pad|>\")\n",
    "            or self._tok.token_to_id(\"<pad>\")\n",
    "        )\n",
    "        self.pad_token_id = pad_from_tok if pad_from_tok is not None else pad_token_id\n",
    "\n",
    "    def encode(self, text):\n",
    "        return self._tok.encode(text).ids\n",
    "\n",
    "    def decode(self, ids):\n",
    "        return self._tok.decode(ids, skip_special_tokens=False)\n",
    "\n",
    "\n",
    "def apply_chat_template(user_text):\n",
    "    return (\n",
    "        \"<|im_start|>user\\n\"\n",
    "        f\"{user_text}\\n\"\n",
    "        \"<|im_end|>\\n\"\n",
    "        \"<|im_start|>assistant\\n\"\n",
    "    )\n",
    "\n",
    "\n",
    "tokenizer_file_path = os.path.join(local_dir, \"tokenizer.json\")\n",
    "if not os.path.exists(tokenizer_file_path):\n",
    "    try:\n",
    "        tokenizer_file_path = hf_hub_download(repo_id=repo_id, filename=\"tokenizer.json\", local_dir=local_dir)\n",
    "    except Exception as e:\n",
    "        print(f\"Warning: failed to download tokenizer.json: {e}\")\n",
    "        tokenizer_file_path = \"tokenizer.json\"\n",
    "\n",
    "tokenizer = OlmoTokenizer(\n",
    "    tokenizer_file_path=tokenizer_file_path,\n",
    "    eos_token_id=OLMO3_CONFIG[\"eos_token_id\"],\n",
    "    pad_token_id=OLMO3_CONFIG[\"pad_token_id\"],\n",
    ")"
   ]
  },
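  {
   "cell_type": "markdown",
   "id": "4f2a9c1e-8b3d-4e7a-a5c6-1d2e3f4a5b69",
   "metadata": {},
   "source": [
    "- `apply_chat_template` wraps a single user message in `<|im_start|>`/`<|im_end|>` markers and ends with an open `assistant` header, so the model's continuation is the assistant reply\n",
    "- The hypothetical `apply_chat_template_multi` sketch below extends the same layout to a list of `{\"role\": ..., \"content\": ...}` messages, for example to continue a conversation\n",
    "- It copies the single-turn format used in this notebook, including the newline before `<|im_end|>`\n",
    "- It is not the official Olmo 3 chat template, which may differ in details such as system prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f2a9c1e-8b3d-4e7a-a5c6-1d2e3f4a5b6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_chat_template_multi(messages, add_generation_prompt=True):\n",
    "    # Hypothetical helper that reuses the turn layout of apply_chat_template\n",
    "    prompt_text = \"\"\n",
    "    for message in messages:\n",
    "        prompt_text += (\n",
    "            f\"<|im_start|>{message['role']}\\n\"\n",
    "            f\"{message['content']}\\n\"\n",
    "            \"<|im_end|>\\n\"\n",
    "        )\n",
    "    if add_generation_prompt:\n",
    "        # Open an assistant turn for the model to complete\n",
    "        prompt_text += \"<|im_start|>assistant\\n\"\n",
    "    return prompt_text\n",
    "\n",
    "\n",
    "demo_messages = [\n",
    "    {\"role\": \"user\", \"content\": \"What is an LLM?\"},\n",
    "    {\"role\": \"assistant\", \"content\": \"A large language model.\"},\n",
    "    {\"role\": \"user\", \"content\": \"Give one example.\"},\n",
    "]\n",
    "print(apply_chat_template_multi(demo_messages))"
   ]
  },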
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<|im_start|>user\\nGive me a short intro to large language models in 3 sentences.\\n<|im_end|>\\n<|im_start|>assistant\\n'"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt = apply_chat_template(\"Give me a short intro to large language models in 3 sentences.\")\n",
    "\n",
    "input_token_ids = tokenizer.encode(prompt)\n",
    "text = tokenizer.decode(input_token_ids)\n",
    "text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57d07df1-4401-4792-b549-7c4cc5632323",
   "metadata": {
    "id": "57d07df1-4401-4792-b549-7c4cc5632323"
   },
   "source": [
    "&nbsp;\n",
    "# 5. Generate text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
   "metadata": {
    "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
   },
   "outputs": [],
   "source": [
    "def generate_text_basic_stream(model, token_ids, max_new_tokens, eos_token_id=None, context_size=None):\n",
    "\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        cache = KVCache(n_layers=model.cfg[\"n_layers\"])\n",
    "        model.reset_kv_cache()\n",
    "\n",
    "        logits = model(token_ids, cache=cache)\n",
    "\n",
    "        for _ in range(max_new_tokens):\n",
    "            next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)\n",
    "\n",
    "            if (eos_token_id is not None\n",
    "                    and torch.all(next_token == eos_token_id)):\n",
    "                break\n",
    "\n",
    "            yield next_token\n",
    "\n",
    "            token_ids = torch.cat([token_ids, next_token], dim=1)\n",
    "\n",
    "            logits = model(next_token, cache=cache)"
   ]
  },
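  {
   "cell_type": "markdown",
   "id": "4f2a9c1e-8b3d-4e7a-a5c6-1d2e3f4a5b6b",
   "metadata": {},
   "source": [
    "- `generate_text_basic_stream` implements greedy decoding with a KV cache\n",
    "- The prompt is processed once, and each later step feeds only the newly selected token; the cache keeps the keys and values of all earlier positions so they are not recomputed\n",
    "- The toy sketch below illustrates that bookkeeping with a hypothetical `ToyKVCache` class. It is not the notebook's `KVCache`, whose interface may differ\n",
    "- In the sketch, the cached keys grow by one position per generated token, and a single new query still attends over every cached position"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f2a9c1e-8b3d-4e7a-a5c6-1d2e3f4a5b6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class ToyKVCache:\n",
    "    # Hypothetical per-layer cache; tensors are (batch, n_heads, seq_len, head_dim)\n",
    "    def __init__(self, n_layers):\n",
    "        self.keys = [None] * n_layers\n",
    "        self.values = [None] * n_layers\n",
    "\n",
    "    def update(self, layer_idx, new_keys, new_values):\n",
    "        if self.keys[layer_idx] is None:\n",
    "            self.keys[layer_idx] = new_keys\n",
    "            self.values[layer_idx] = new_values\n",
    "        else:\n",
    "            # Append the new positions along the sequence dimension\n",
    "            self.keys[layer_idx] = torch.cat([self.keys[layer_idx], new_keys], dim=2)\n",
    "            self.values[layer_idx] = torch.cat([self.values[layer_idx], new_values], dim=2)\n",
    "        return self.keys[layer_idx], self.values[layer_idx]\n",
    "\n",
    "\n",
    "toy_cache = ToyKVCache(n_layers=1)\n",
    "toy_batch, toy_heads, toy_head_dim = 1, 2, 8\n",
    "\n",
    "# Prefill: the 5-token prompt is processed in a single call\n",
    "cached_k, cached_v = toy_cache.update(\n",
    "    0,\n",
    "    torch.randn(toy_batch, toy_heads, 5, toy_head_dim),\n",
    "    torch.randn(toy_batch, toy_heads, 5, toy_head_dim),\n",
    ")\n",
    "\n",
    "# Decoding: each step adds exactly one new position\n",
    "for _ in range(3):\n",
    "    cached_k, cached_v = toy_cache.update(\n",
    "        0,\n",
    "        torch.randn(toy_batch, toy_heads, 1, toy_head_dim),\n",
    "        torch.randn(toy_batch, toy_heads, 1, toy_head_dim),\n",
    "    )\n",
    "print(cached_k.shape)  # torch.Size([1, 2, 8, 8])\n",
    "\n",
    "# The newest token's single query still attends over all 8 cached positions\n",
    "toy_query = torch.randn(toy_batch, toy_heads, 1, toy_head_dim)\n",
    "toy_scores = toy_query @ cached_k.transpose(-2, -1)\n",
    "print(toy_scores.shape)  # torch.Size([1, 2, 1, 8])"
   ]
  },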
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
   "metadata": {
    "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sure! Here’s a brief introduction to large language models:  \n",
      "Large models are advanced AI systems trained to process vast neural networks capable of understanding and generating text, learning from vast amounts of data, learning language, performing diverse tasks, assisting in many applications, and adapting various tasks.\n",
      "\n",
      "GPU memory used: 13.71 GB\n"
     ]
    }
   ],
   "source": [
    "input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.reset_peak_memory_stats()\n",
    "\n",
    "\n",
    "for token in generate_text_basic_stream(\n",
    "    model=model,\n",
    "    token_ids=input_token_ids_tensor,\n",
    "    max_new_tokens=500,\n",
    "    eos_token_id=tokenizer.eos_token_id\n",
    "):\n",
    "    token_id = token.squeeze(0).tolist()\n",
    "    print(\n",
    "        tokenizer.decode(token_id),\n",
    "        end=\"\",\n",
    "        flush=True\n",
    "    )\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    def gpu_gb(x):\n",
    "        return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
    "    \n",
    "    print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")"
   ]
  },
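  {
   "cell_type": "markdown",
   "id": "4f2a9c1e-8b3d-4e7a-a5c6-1d2e3f4a5b6d",
   "metadata": {},
   "source": [
    "- The reported number divides bytes by 1024³, so it is strictly in gibibytes (GiB). In decimal gigabytes (10⁹ bytes), the same amount is about 7% larger\n",
    "- The weights alone need roughly (number of parameters) × (bytes per element), i.e., 2 bytes per parameter in bfloat16\n",
    "- The peak value above also covers activations, the KV cache, and other temporary tensors\n",
    "- The hypothetical helper below sums the parameter and buffer memory of any `nn.Module`. Calling `weight_memory_bytes(model)` should report the loaded model's footprint; the example uses a small toy layer instead"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f2a9c1e-8b3d-4e7a-a5c6-1d2e3f4a5b6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "def weight_memory_bytes(module):\n",
    "    # Hypothetical helper; module.parameters() yields tied (shared) parameters only once\n",
    "    n_bytes = sum(p.numel() * p.element_size() for p in module.parameters())\n",
    "    n_bytes += sum(b.numel() * b.element_size() for b in module.buffers())\n",
    "    return n_bytes\n",
    "\n",
    "\n",
    "toy_layer = nn.Linear(1024, 1024, bias=False).to(torch.bfloat16)\n",
    "toy_bytes = weight_memory_bytes(toy_layer)\n",
    "print(toy_bytes)  # 2097152 (1024 * 1024 parameters * 2 bytes each)\n",
    "print(f\"{toy_bytes / 1024**2:.2f} MiB vs. {toy_bytes / 1e6:.2f} MB\")  # 2.00 MiB vs. 2.10 MB"
   ]
  },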
  {
   "cell_type": "markdown",
   "id": "549324d6-5c71-4147-ae21-2e67675faa3d",
   "metadata": {
    "id": "549324d6-5c71-4147-ae21-2e67675faa3d"
   },
   "source": [
    "&nbsp;\n",
    "# What's next?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c",
   "metadata": {
    "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c"
   },
   "source": [
    "- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n",
    "\n",
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
